
KV Cache for flow DiT #1875

Open
hcwang26 wants to merge 1 commit into FunAudioLLM:main from hcwang26:main

Conversation

@hcwang26

Summary

Adds KV-cache / chunked streaming support for the flow DiT decoder (CosyVoice3), enabling incremental causal diffusion inference with reusable attention & conv caches. Fully backward-compatible with the existing CosyVoice2 UNet path.
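The core mechanism is easiest to see in isolation. Below is a minimal, hypothetical sketch of the two ideas combined here: rotating Q/K with an absolute-position offset, so a chunk processed on its own gets the same rotary phases it would in a full-sequence pass, and prepending cached K/V so the new chunk attends over all past positions. The single-head layout, interleaved-pair RoPE, and all names are illustrative assumptions, not the PR's exact code.

```python
# Minimal, hypothetical sketch of offset-aware RoPE plus a K/V cache.
# Shapes: x is (batch, seq, dim); freqs is (max_seq, dim // 2).
import torch


def apply_rotary_pos_emb(x, freqs, offset=0):
    """Rotate interleaved channel pairs by position-dependent angles;
    `offset` shifts the position index so a chunk starting at absolute
    position `offset` matches its full-sequence rotation."""
    seq = x.size(1)
    f = freqs[offset:offset + seq]               # (seq, dim // 2)
    cos, sin = f.cos(), f.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


def attention_chunk(q, k, v, freqs, offset, att_cache=None):
    """Block-causal chunked attention: the new chunk attends to itself
    (full attention within the chunk) plus all cached past positions."""
    q = apply_rotary_pos_emb(q, freqs, offset)
    k = apply_rotary_pos_emb(k, freqs, offset)   # cached K was rotated earlier
    if att_cache is not None:                    # att_cache: (2, batch, past, dim)
        k = torch.cat([att_cache[0], k], dim=1)
        v = torch.cat([att_cache[1], v], dim=1)
    new_att_cache = torch.stack([k, v])          # reused by the next chunk
    scores = q @ k.transpose(-1, -2) / q.size(-1) ** 0.5
    return torch.softmax(scores, dim=-1) @ v, new_att_cache
```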

Changes

  • cosyvoice/flow/DiT/modules.py: Custom einops-based apply_rotary_pos_emb with an offset argument for chunk-aware RoPE (the sketch after the Summary above illustrates the pattern). CausalConvPositionEmbedding.forward now accepts and updates a conv_cache. AttnProcessor.__call__ accepts x_offset / att_cache, splits and concatenates past K/V, pads attn_mask as the cache grows, and returns new_att_cache. New Attention.forward_chunk and DiTBlock.forward_chunk methods.
  • cosyvoice/flow/DiT/dit.py: InputEmbedding.forward threads conv_cache through the causal conv. DiT.forward is numerically unchanged (it now returns an (output, None) tuple for interface consistency). New DiT.forward_chunk(x, x_offset, mask, mu, t, spks, cond, ..., conv_cache=None, att_cache=None) performs offset-aware RoPE plus per-block forward_chunk, returning (output, new_conv_cache, stacked_new_att_cache).
  • cosyvoice/flow/flow_matching.py: ConditionalCFM.forward_estimator normalized to always return a tuple. New solve_euler_chunk diffusion solver (a hedged solver sketch follows this list) and forward_estimator_with_cache, which handles both a torch.nn.Module DiT and a TRT-wrapped estimator with cache bindings (x_offset, conv_cache, and att_cache become extra TRT inputs/outputs). CausalConditionalCFM.forward kept byte-identical to upstream (UNet path untouched); the new CausalConditionalCFM.forward_chunk is the dedicated DiT/KV-cache entrypoint. compute_loss tolerates tuple returns from DiT estimators.
  • cosyvoice/flow/flow.py: CausalMaskedDiffWithDiT.inference now calls self.decoder.forward_chunk(..., x_offset=0) for full inference. New CausalMaskedDiffWithDiT.inference_chunk(token, token_offset, ..., conv_cache, att_cache, ..., init_cache=False, chunk_size=25, n_timesteps=10) enables streaming decode with reusable caches and prompt-aware h_offset slicing (see the streaming-loop sketch below).
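To make the cache plumbing concrete, here is a hedged sketch of a solve_euler_chunk-style solver: fixed-step Euler integration of the flow ODE for one chunk, with the returned caches fed into the next chunk. One assumption not stated in the PR: a separate cache is kept per timestep, since hidden states differ at each noise level. Classifier-free guidance and the exact estimator signature are simplified.

```python
import torch


@torch.no_grad()
def solve_euler_chunk(estimator, x, x_offset, mask, mu, spks, cond,
                      n_timesteps=10, conv_caches=None, att_caches=None):
    t_span = torch.linspace(0.0, 1.0, n_timesteps + 1, device=x.device)
    if conv_caches is None:                      # first chunk: empty caches
        conv_caches = [None] * n_timesteps
        att_caches = [None] * n_timesteps
    new_conv, new_att = [], []
    for i in range(n_timesteps):
        t = t_span[i].expand(x.size(0))
        dt = t_span[i + 1] - t_span[i]
        # estimator is DiT.forward_chunk (or its TRT-wrapped equivalent)
        v, conv_cache, att_cache = estimator(
            x, x_offset=x_offset, mask=mask, mu=mu, t=t, spks=spks,
            cond=cond, conv_cache=conv_caches[i], att_cache=att_caches[i])
        x = x + dt * v                           # explicit Euler step
        new_conv.append(conv_cache)
        new_att.append(att_cache)
    return x, new_conv, new_att                  # feed caches to the next chunk
```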
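And a hypothetical caller's view of the streaming entrypoint, using the inference_chunk arguments listed above. The return signature, the init_cache-on-first-chunk convention, and the surrounding glue (tokens, flow, the vocoder hand-off) are all assumptions.

```python
import torch

conv_cache, att_cache = None, None
token_offset, chunk_size = 0, 25                 # 25 speech tokens per chunk
mel_chunks = []
while token_offset < tokens.size(1):
    chunk = tokens[:, token_offset:token_offset + chunk_size]
    mel, conv_cache, att_cache = flow.inference_chunk(
        token=chunk,                             # assumed: only the new chunk
        token_offset=token_offset,
        conv_cache=conv_cache,
        att_cache=att_cache,
        init_cache=(token_offset == 0),          # assumed first-chunk behavior
        chunk_size=chunk_size,
        n_timesteps=10,
    )
    mel_chunks.append(mel)                       # hand each mel chunk to the vocoder
    token_offset += chunk_size
mel = torch.cat(mel_chunks, dim=-1)
```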

Performance

  • Roughly 1.5x to 2.0x faster inference on an NVIDIA L20 GPU.

Compatibility

  • CausalConditionalCFM.forward signature and behavior match upstream exactly; CosyVoice2 inference is unaffected.
  • The DiT PyTorch module and the TRT-exported engine are handled uniformly: the TRT branch in forward_estimator_with_cache wires x_offset / conv_cache / att_cache as explicit TRT tensor bindings, so no hasattr or runtime-type dispatch is required at the CFM level.
  • Training (compute_loss) is safe against estimators that return (pred, cache) tuples (a two-line sketch follows).
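The tuple tolerance amounts to a short normalization at the compute_loss call site; a minimal sketch, assuming the upstream call shape:

```python
# Accept both a bare prediction and a (pred, *caches) tuple from the estimator.
out = self.estimator(y, mask=mask, mu=mu, t=t, spks=spks, cond=cond)
pred = out[0] if isinstance(out, tuple) else out
```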

Notes

  • Only CausalMaskedDiffWithDiT (CosyVoice3) uses the new chunked path; other flow variants remain unchanged.
  • 4 files touched, +434 / −42 lines total.

@gerayking

nice pr

@aluminumbox
Collaborator

Thanks! When we previously experimented with caching, the problem we hit was that the GPU memory used by the K/V cache grows significantly as the audio gets longer. Have you measured how much concurrency, say, 40 GB of GPU memory can support?

@gerayking

> Thanks! When we previously experimented with caching, the problem we hit was that the GPU memory used by the K/V cache grows significantly as the audio gets longer. Have you measured how much concurrency, say, 40 GB of GPU memory can support?

That should scale with the total length of audio being decoded at the same time. Could we exchange contact information?
